Feature/hiera #418

Merged
benjijamorris merged 31 commits from feature/hiera into main on Aug 26, 2024

Conversation

@benjijamorris (Contributor) commented on Aug 19, 2024

What does this PR do?

  • Add 2D and 3D Hiera models
  • Refactor MAE/Hiera into shared encoder/decoder.py files instead of model-specific files
  • Unify patchify code
  • Add tests
  • Update JEPA to accommodate the refactoring

Before submitting

  • Did you make sure title is self-explanatory and the description concisely explains the PR?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you test your PR locally with pytest command?
  • Did you run pre-commit hooks with pre-commit run -a command?

Did you have fun?

Make sure you had fun coding 🙃

ritvikvasan previously approved these changes on Aug 21, 2024

@ritvikvasan (Member) left a comment:
This is a lot!! Loved reading it...As always, minor comments

backbone:
_target_: cyto_dl.nn.vits.mae.HieraMAE
spatial_dims: ${spatial_dims}
patch_size: 2 # patch_size* num_patches should be your patch shape
@ritvikvasan (Member): should be image shape?

@ritvikvasan (Member): can this be a list for ZYX patch size?

@benjijamorris (Author): yes - the terminology is confusing here haha. "patch" = the small crop extracted from your original image, but "patch" is also the tokenized component of the image fed into the network. The patch size can be either an int (repeated for each spatial dim) or a list of size spatial_dims.
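
For example (illustrative values only, not from the PR), both of the following describe a valid token size for a 3D model:

# Illustrative only: patch_size accepts either form when spatial_dims = 3.
patch_size_isotropic = 2       # an int is repeated for each spatial dim -> (2, 2, 2)
patch_size_zyx = [4, 2, 2]     # an anisotropic ZYX list; length must equal spatial_dims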

spatial_dims: ${spatial_dims}
patch_size: 2 # patch_size* num_patches should be your patch shape
num_patches: 8 # patch_size * num_patches = img_shape
num_mask_units: 4 # img_shape / num_mask_units = size of each mask unit in pixels, num_patches / num_mask_units = number of patches per mask unit
@ritvikvasan (Member): Clarify what a mask unit is here?
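
For reference, a worked sketch of how the values above relate (variable names are illustrative; the numbers come straight from the config comments):

# Illustrative arithmetic only, per spatial dim.
patch_size = 2       # side length of one token, in pixels
num_patches = 8      # tokens per spatial dim
num_mask_units = 4   # mask units per spatial dim

img_shape = patch_size * num_patches                    # 16 pixels
mask_unit_pixels = img_shape // num_mask_units          # 4 pixels: size of each mask unit
patches_per_mask_unit = num_patches // num_mask_units   # 2 tokens per mask unit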

architecture:
# mask_unit_attention blocks - attention is only done within a mask unit and not across mask units
# the total amount of q_stride across the architecture must be less than the number of patches per mask unit
- repeat: 1
@ritvikvasan (Member): what is repeat?

# self attention transformer - attention is done across all patches, irrespective of which mask unit they're in
- repeat: 2
num_heads: 4
self_attention: True
@ritvikvasan (Member): so last layer is global attention and first 2 layers are local attention? Is 3 layers the recommended hierarchy?

@benjijamorris (Author): correct. 3 layers is small enough to test quickly. All of the models with unit tests are tiny by default in the configs, and I note somewhere in the docs that you should increase the model size if you want good performance.

if self.spatial_dims == 3:
q = reduce(
q,
"b n h (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) c ->b n h (n_patches_z n_patches_y n_patches_x) c",
@ritvikvasan (Member): can you use the same nomenclature here? e.g. n = num_mask_units = mask_units, num_heads = h = n_heads

@ritvikvasan (Member): c = head_dim
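
A minimal runnable sketch of that renaming suggestion (shapes and the max reduction are assumptions for illustration, not the PR's exact code):

import torch
from einops import reduce

# Illustrative shapes: batch 1, 4 mask units, 2 heads, 2x2x2 patches per mask unit, head_dim 8.
n_mask_units, n_heads, head_dim = 4, 2, 8
q = torch.randn(1, n_mask_units, n_heads, 2 * 2 * 2, head_dim)

# Pool queries within each mask unit by q_stride along each spatial axis,
# using the clearer axis names suggested above.
q_pooled = reduce(
    q,
    "b n_mask_units n_heads (n_patches_z q_stride_z n_patches_y q_stride_y n_patches_x q_stride_x) head_dim "
    "-> b n_mask_units n_heads (n_patches_z n_patches_y n_patches_x) head_dim",
    "max",
    q_stride_z=2, q_stride_y=2, q_stride_x=2,
    n_patches_z=1, n_patches_y=1, n_patches_x=1,
)
print(q_pooled.shape)  # torch.Size([1, 4, 2, 1, 8])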

self.spatial_dims = spatial_dims
self.num_heads = num_heads
self.head_dim = dim_out // num_heads
self.scale = qk_scale or self.head_dim**-0.5
@ritvikvasan (Member): this isn't used anywhere

# change dimension and subsample within mask unit for skip connection
x = self.proj(x_norm)

x = x + self.drop_path(self.attn(x_norm))
@ritvikvasan (Member): Does dim_out = dim for skip connection with attention?

@benjijamorris (Author): Good question - each block specified in the architecture argument doubles the embedding dimension and halves the size of the mask unit. This doubling/pooling happens on the last repeat of the block, so dim_out=dim for all repeats except the last. I updated the docstring with an example.

@ritvikvasan (Member): cool!
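
A worked sketch of that doubling rule (illustrative starting dim and repeat counts, not the actual defaults):

# dim_out == dim on every repeat except the last repeat of each block, where it doubles.
dim = 32
repeats_per_block = [1, 2, 2]  # hypothetical architecture
for block, repeats in enumerate(repeats_per_block):
    for r in range(repeats):
        dim_out = dim * 2 if r == repeats - 1 else dim
        print(f"block {block} repeat {r}: dim={dim} -> dim_out={dim_out}")
        dim = dim_out
# block 0 repeat 0: dim=32 -> dim_out=64
# block 1 repeat 0: dim=64 -> dim_out=64
# block 1 repeat 1: dim=64 -> dim_out=128
# block 2 repeat 0: dim=128 -> dim_out=128
# block 2 repeat 1: dim=128 -> dim_out=256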

dim_out: int,
heads: int,
spatial_dims: int = 3,
mlp_ratio: float = 4.0,
@ritvikvasan (Member): what is mlp_ratio? add to docstring?
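
For context, mlp_ratio presumably follows the usual transformer-block convention: it sets the hidden width of the block's MLP relative to its embedding dimension. A one-line sketch (assumed, not taken from the PR):

dim_out, mlp_ratio = 128, 4.0
mlp_hidden_dim = int(dim_out * mlp_ratio)  # 512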



class PatchifyHiera(PatchifyBase):
"""Class for converting images to a masked sequence of patches with positional embeddings."""
@ritvikvasan (Member): to "mask units" instead of masked sequence? since that's what a regular patchify does?

@@ -40,3 +47,8 @@ def get_positional_embedding(
cls_token = torch.zeros(1, 1, emb_dim)
pe = torch.cat([cls_token, pe], dim=0)
return torch.nn.Parameter(pe, requires_grad=False)


def validate_spatial_dims(spatial_dims, tuples):
@ritvikvasan (Member): I feel like the code might be clearer by not having this be a separate function and just calling these 2 lines in every class? I thought this function was doing a lot more based on the name (like some math to check that the spatial dimensions of each patch and mask are correct). What do you think?

@benjijamorris (Author): I'd prefer renaming it to something clearer rather than repeating the code, maybe match_tuple_to_spatial_dims?

@ritvikvasan (Member): sounds good
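
A hypothetical sketch of what the renamed helper could look like (the real implementation may differ):

from typing import Sequence, Union

def match_tuple_to_spatial_dims(spatial_dims: int, tuples: Sequence[Union[int, Sequence[int]]]):
    """Broadcast ints to one value per spatial dim and check that lists match spatial_dims."""
    out = []
    for t in tuples:
        t = [t] * spatial_dims if isinstance(t, int) else list(t)
        assert len(t) == spatial_dims, f"expected {spatial_dims} values, got {t}"
        out.append(t)
    return out

# match_tuple_to_spatial_dims(3, [2, [4, 2, 2]]) -> [[2, 2, 2], [4, 2, 2]]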

@ritvikvasan (Member): windows tests always seem to be failing. any ideas why?

@benjijamorris (Author): the windows tests are just way slower for some reason... all the tests pass, but I set a 70 minute timeout so we don't rack up crazy costs.

ritvikvasan previously approved these changes on Aug 23, 2024
@benjijamorris merged commit 48e7cb9 into main on Aug 26, 2024 (4 of 6 checks passed)
@benjijamorris deleted the feature/hiera branch on August 26, 2024 at 17:42